from Unet import SimpleUNet
from Config import Config
from torch import load

def load_model(model_path):
    """Build a ``SimpleUNet`` from ``Config``, restore saved weights, and return it.

    Args:
        model_path: Checkpoint location handed unchanged to ``torch.load``;
            expected to contain a state dict matching ``SimpleUNet``'s layout.

    Returns:
        The ``SimpleUNet`` instance, moved to ``Config.device``, with the
        checkpoint weights loaded and ``eval()`` applied.
    """
    target_device = Config.device

    network = SimpleUNet(Config.channels, Config.timesteps, Config.image_size)
    # nn.Module.to() moves parameters in place and returns the same module.
    network = network.to(target_device)

    # map_location remaps every stored tensor onto target_device, so a
    # checkpoint written on one device type can be restored on another.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state = load(model_path, map_location=target_device)
    network.load_state_dict(state)

    # Switch to inference mode before handing the model back.
    network.eval()
    return network